from multiprocessing.dummy import Pool as ThreadPool
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Placeholders shown in the plot title until each solver has produced a result.
dp_ans = ant_ans = "?"


def plt_reset():
    """Clear the current axes and re-apply the shared plot setup.

    Hides the axes, requests equal scaling, fixes both the x and the y range
    to [-100, 100], and writes the current ``dp_ans`` / ``ant_ans`` values
    into the title.
    """
    plt.cla()
    plt.style.use('fivethirtyeight')
    plt.axis('off')  # hide the axes
    plt.axis('equal')
    plt.xlim(-100, 100)
    # y range matches the x range so the whole coordinate square stays in view
    plt.ylim(-100, 100)
    plt.title(f"dp_ans:{dp_ans}\n ant_ans:{ant_ans}")


# Switch matplotlib to interactive mode, draw the empty canvas once, then
# pause for 3 seconds before the solvers start.
plt.ion()
plt_reset()
plt.pause(3)

'''
第一步：初始化参数
'''
n = 12  # number of cities
Q = 1  # pheromone deposited in the ant-cycle model = Q / ant[k].length
# City coordinates: n random integer points with x and y drawn from [-100, 100).
COORDINATE = np.random.randint(-100, 100, size=(n, 2))
# Index set of all cities; visited cities are subtracted from it to get the
# candidates for the next move.
C = set(range(n))


def distance(A, B):
    """Return the distance between coordinates ``A`` and ``B``.

    Computed as the norm of the difference vector ``A - B``.
    """
    offset = A - B
    return np.linalg.norm(offset)

# COORDINATE = np.asarray([[100,100],[100,-100],[-100,100],[-100,-100]])
# theta = np.linspace(0, 2*np.pi, n)
# x = np.cos(theta)*100
# y = np.sin(theta)*100

# x2 = np.cos(theta)*10
# y2 = np.sin(theta)*10
# COORDINATE = np.asarray(list(zip(x, y))+list(zip(x2, y2)))
# n = len(COORDINATE)


'''
dp 求解
'''
# dp[status][v]: shortest length found so far for a path that starts at city 0,
# has entered the cities whose bits are set in `status`, and ends at city v.
# A stored 0 is also what dfs() treats as "not reached yet".
dp = np.zeros((1 << n, n))
# Most recently recorded complete tour, as a list of city indices.
ans = []


def dfs(u, status, tabu):
    """Depth-first relaxation of the ``dp`` table from city ``u``.

    u      -- city the partial path currently ends at
    status -- bitmask of the cities entered so far
    tabu   -- the partial path as a list of city indices

    Every city not yet on the path is tried as the next stop.  A branch is
    followed only when it reaches ``dp[status | 1 << v][v]`` for the first
    time (stored value 0) or with a shorter length than the stored one.
    Once every city is on the path, city 0 is the only candidate left, and
    arriving there stores the finished path in the global ``ans``.
    """
    global ans
    global dp
    remaining = C - set(tabu)
    if not remaining:
        if u == 0:
            # all cities visited and back at the start: record the tour
            ans = tabu
            return
        # all cities visited: the only move left is the return to city 0
        remaining.add(0)
    for v in remaining:
        reached = status | (1 << v)
        candidate = dp[status][u] + distance(COORDINATE[u], COORDINATE[v])
        stored = dp[reached][v]
        if stored == 0 or stored > candidate:
            dp[reached][v] = candidate
            dfs(v, reached, tabu + [v])


# Start the search at city 0 with an empty bitmask and a path that holds only
# the start city; dfs() fills `dp` and records the tour in `ans`.
dfs(0, 0, [0])
# print(dp)
# Length stored for "every bit set, ending at city 0", followed by the tour.
print(dp[(1 << n)-1][0])
print(ans)

# Keep the length as text; plt_reset() puts it into the plot title.
dp_ans = str(dp[(1 << n)-1][0])


def show_the_path(path, title):
    """Draw ``path`` one edge at a time and save the finished figure.

    path  -- sequence of city indices; consecutive entries are joined by a line
    title -- passed to ``plt.savefig`` as the output file name
    """
    plt_reset()
    plt.scatter(COORDINATE[:, 0], COORDINATE[:, 1])
    for i, j in zip(path, path[1:]):
        ends = COORDINATE[[i, j]]
        plt.plot(ends[:, 0], ends[:, 1])
        plt.pause(0.2)  # brief pause so the edges appear one after another
    plt.savefig(title)


# Animate the tour recorded by dfs() and save the figure under the name "dp".
show_the_path(ans, "dp")

'''
dp求解完成
'''
N_cmax = 1000  # number of colony iterations to run
N_c = 0  # iteration counter
t = 0  # not actually used here

const = 1  # initial pheromone level on every edge
tau = const * np.ones((n, n))  # pheromone matrix
eta = np.zeros((n, n))  # heuristic (prior) matrix: inverse distance
for i in range(n):
    for j in range(n):
        # the diagonal stays 0: no self-loops
        if i != j:
            eta[i][j] = 1 / distance(COORDINATE[i], COORDINATE[j])


def Pk_ij(k, i, alpha=1, beta=1):
    """State-transition rule for ant ``k`` standing at city ``i``.

    The candidates are all cities missing from the ant's tabu list; once
    none is left and the ant is not yet back at its start, the start city
    becomes the only candidate.

    Returns ``(P, allowed)``: ``allowed`` is the list of candidate cities and
    ``P[x]`` is the weight ``tau**alpha * eta**beta`` of moving from ``i`` to
    ``allowed[x]``, divided by the sum of these weights over all candidates.
    """
    walker = ant[k]
    allowed = C - set(walker.tabu)
    if not allowed and walker.tabu[-1] != walker.start:
        # every city visited: go back to the starting point
        allowed.add(walker.start)
    allowed = list(allowed)
    weights = np.power(tau[i][allowed], alpha) * np.power(eta[i][allowed], beta)
    return weights / np.sum(weights), allowed


'''
第二步将m只蚂蚁随机放在n个城市上
'''
m = 100  # number of ants in the colony


class Ant(object):
    """One ant of the colony.

    Attributes:
        pos: index of the city the ant is currently in.
        start: index of the city the ant was created on.
        tabu: the ant's tabu list (cities already visited).
        length: total length of the path walked in the current TSP tour.
    """

    def __init__(self, pos):
        self.pos = pos
        self.start = pos
        # Created per instance: a class-level ``tabu = []`` would be a single
        # list object shared (and mutated) by every ant.
        self.tabu = []
        self.length = 0


# Build the colony: m ants, each created on a randomly drawn city.
ant = [Ant(np.random.randint(n)) for _ in range(m)]


def show_ant_system():
    """Redraw the pheromone map.

    One blue line is drawn for every ordered pair of cities ``(i, j)``, with
    the line width taken from ``tau[i][j]``.
    """
    plt_reset()
    plt.scatter(COORDINATE[:, 0], COORDINATE[:, 1])
    for i, j in np.ndindex(n, n):
        ends = COORDINATE[[i, j]]
        plt.plot(ends[:, 0], ends[:, 1], linewidth=tau[i][j], color="blue")
    plt.pause(0.001)


'''
循环次数+1
'''
# pool = ThreadPool(6)
def update_tau(p=0.1):
    """Ant-cycle pheromone update applied once per colony iteration.

    p -- evaporation rate; the old pheromone is scaled by ``1 - p``.

    Every ant deposits ``Q / length`` on each directed edge of the tour
    stored in its tabu list.  Edges are taken as consecutive pairs of the
    tour, so no pheromone is placed on the self-loop of the start city.
    Prints the largest pheromone value after the update.
    """
    global tau
    # the n in tau_ij(t+n) can be ignored here
    # formula 6.5: pheromone laid down by all ants in this iteration
    delta_tau = np.zeros([n, n])
    for k in range(m):
        tour = ant[k].tabu
        deposit = Q / ant[k].length
        for i, j in zip(tour, tour[1:]):
            delta_tau[i][j] += deposit
    # formula 6.4: evaporation plus the new deposits
    tau = (1 - p) * tau + delta_tau
    print(np.max(tau))


for N_c in range(N_cmax):
    # Clear the tabu lists: every ant restarts on a randomly drawn city.
    for k in range(m):
        ant[k].pos = np.random.randint(n)
        ant[k].tabu = [ant[k].pos]
        ant[k].start = ant[k].pos
        ant[k].length = 0

    for k in range(m):
        # This is TSP on a complete graph, so every ant finishes its tour
        # after exactly n moves (n - 1 new cities plus the return to start).
        for c in range(n):
            # Compute the transition probabilities (formula 6.2), pick the
            # city with the highest one, move the ant there and append that
            # city to its tabu list.
            i = ant[k].pos
            probs, candidates = Pk_ij(k, i)
            j = candidates[np.argmax(probs)]
            ant[k].pos = j
            ant[k].tabu.append(j)
            ant[k].length += distance(COORDINATE[i], COORDINATE[j])
    # pool.map(ant_go, range(m))

    # The reported value is the tour length of ant 0, which is not
    # necessarily the shortest tour of this iteration.
    ant_ans = ant[0].length
    update_tau()
    show_ant_system()
    if N_c % 100 == 0:
        print(ant_ans)
        show_the_path(ant[0].tabu, "ant")
